# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from copy import deepcopy
from datetime import date
from io import StringIO
from pathlib import Path

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal

from mne import pick_types, read_cov, read_evokeds
from mne._fiff.pick import _picks_by_type
from mne.epochs import make_fixed_length_epochs
from mne.fixes import _eye_array
from mne.io import read_raw_fif
from mne.time_frequency import tfr_morlet
from mne.utils import (
    _PCA,
    _apply_scaling_array,
    _apply_scaling_cov,
    _array_equal_nan,
    _custom_lru_cache,
    _date_to_julian,
    _freq_mask,
    _get_inst_data,
    _julian_to_date,
    _reg_pinv,
    _replace_md5,
    _ReuseCycle,
    _time_mask,
    _undo_scaling_array,
    _undo_scaling_cov,
    compute_corr,
    create_slices,
    grand_average,
    hashfunc,
    numerics,
    object_diff,
    object_hash,
    object_size,
    random_permutation,
    sum_squared,
)
from mne.utils.numerics import _LRU_CACHE_MAXSIZES, _LRU_CACHES

# Fixture files used by the tests below. ``parents[2]`` is the grandparent of
# the directory that holds this module; the data live under ``io/tests/data``.
base_dir = Path(__file__).parents[2].joinpath("io", "tests", "data")
fname_raw = base_dir.joinpath("test_raw.fif")
ave_fname = base_dir.joinpath("test-ave.fif")
cov_fname = base_dir.joinpath("test-cov.fif")


def test_get_inst_data():
    """Test that _get_inst_data returns each instance's underlying data.

    A raw object, fixed-length epochs, their average and the ``tfr_morlet``
    output are checked in turn; a plain string must raise ``TypeError``.
    """
    raw = read_raw_fif(fname_raw)
    raw.crop(tmax=1.0)
    assert_array_equal(_get_inst_data(raw), raw._data)
    # restrict to the first two channels for the objects derived below
    first_two = raw.ch_names[:2]
    raw.pick(first_two)

    epochs = make_fixed_length_epochs(raw, 0.5)
    assert_array_equal(_get_inst_data(epochs), epochs._data)

    evoked = epochs.average()
    assert_array_equal(_get_inst_data(evoked), evoked.data)

    evoked.crop(tmax=0.1)
    tfr = tfr_morlet(evoked, [50.0, 55.0], 3, return_itc=False, picks=[0, 1])
    assert_array_equal(_get_inst_data(tfr), tfr.data)

    # anything that is not a supported instance is rejected
    with pytest.raises(TypeError):
        _get_inst_data("foo")


def test_hashfunc(tmp_path):
    """Test md5/sha1 file hashes are stable and content-dependent.

    Each file is hashed twice, once with the default second argument and once
    with an explicit one (1 or 1024); both digests must agree, while the two
    files (different contents) must hash differently.
    """
    fname1 = tmp_path / "foo"
    fname2 = tmp_path / "bar"
    fname1.write_bytes(b"abcd")
    fname2.write_bytes(b"efgh")

    for hash_type in ("md5", "sha1"):
        foo_default = hashfunc(fname1, hash_type=hash_type)
        foo_explicit = hashfunc(fname1, 1, hash_type=hash_type)

        bar_default = hashfunc(fname2, hash_type=hash_type)
        bar_explicit = hashfunc(fname2, 1024, hash_type=hash_type)

        assert foo_default == foo_explicit
        assert bar_default == bar_explicit
        assert foo_default != bar_default


def test_sum_squared():
    """Test sum_squared against a plain NumPy sum of squares."""
    rng = np.random.RandomState(0)
    X = rng.randint(0, 50, (3, 3))
    expected = np.sum(X**2)
    assert expected == sum_squared(X)


def test_compute_corr():
    """Test compute_corr against np.corrcoef on Anscombe-style data."""
    x = np.array([10, 8, 13, 9, 11, 14, 6, 4, 12, 7, 5])
    y = np.array(
        [
            [8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68],
            [9.14, 8.14, 8.74, 8.77, 9.26, 8.10, 6.13, 3.10, 9.13, 7.26, 4.74],
            [7.46, 6.77, 12.74, 7.11, 7.81, 8.84, 6.08, 5.39, 8.15, 6.42, 5.73],
            [8, 8, 8, 8, 8, 8, 8, 19, 8, 8, 8],
            [6.58, 5.76, 7.71, 8.84, 8.47, 7.04, 5.25, 12.50, 5.56, 7.91, 6.89],
        ]
    )

    # compute_corr receives one series per column, hence the transpose
    r = compute_corr(x, y.T)
    # reference: correlate x with every row of y, one at a time
    r2 = np.array([np.corrcoef(x, row)[0, 1] for row in y])
    assert_allclose(r, r2)
    with pytest.raises(ValueError):
        compute_corr([1, 2], [])


def test_create_slices():
    """Test the count and bounds of the slices built by create_slices."""
    # An empty range gives an empty list
    assert create_slices(0, 0) == []
    # Number of slices for default arguments, a non-zero start, a window
    # length, a manual step, and overlapping windows (length > step)
    for bounds, options, n_expected in (
        ((0, 100), dict(), 100),
        ((50, 100), dict(), 50),
        ((0, 100), dict(length=2), 50),
        ((0, 100), dict(step=10), 10),
        ((0, 500), dict(length=50, step=10), 46),
    ):
        assert len(create_slices(*bounds, **options)) == n_expected

    # Start, stop and step of the first/last slice with default arguments
    slices = create_slices(0, 10)
    first, last = slices[0], slices[-1]
    assert (first.start, first.step, first.stop) == (0, 1, 1)
    assert last.stop == 10

    # Same with a larger window width
    slices = create_slices(0, 9, length=3)
    first, last = slices[0], slices[-1]
    assert (first.start, first.step, first.stop) == (0, 1, 3)
    assert last.stop == 9

    # Same with a manual separation between slices
    slices = create_slices(0, 9, length=3, step=1)
    assert len(slices) == 7
    first, last = slices[0], slices[-1]
    assert (first.step, first.stop) == (1, 3)
    assert (last.start, last.stop) == (6, 9)


def test_time_mask():
    """Test _time_mask with jittered, non-uniform and degenerate inputs."""
    N = 10
    times = np.arange(N).astype(float)
    assert _time_mask(times, 0, N - 1).sum() == N
    # a tiny negative offset must not drop samples once sfreq is given
    jittered = times - 1e-10
    for tmin, tmax in (
        (0, N - 1),
        (None, N - 1),
        (None, None),
        (-np.inf, None),
        (None, np.inf),
    ):
        assert _time_mask(jittered, tmin, tmax, sfreq=1000.0).sum() == N

    # non-uniformly spaced inputs
    x = np.array([4, 10])
    only_first = x[:1]
    assert _time_mask(only_first, tmin=10, sfreq=1, raise_error=False).sum() == 0
    assert (
        _time_mask(only_first, tmin=11, tmax=12, sfreq=1, raise_error=False).sum()
        == 0
    )
    for tmin, n_kept in (
        (10, 1),
        (6, 1),
        (5, 1),
        (4.5001, 1),
        (4.4999, 2),
        (4, 2),
    ):
        assert _time_mask(x, tmin=tmin, sfreq=1).sum() == n_kept

    # degenerate cases
    with pytest.raises(ValueError, match="No samples remain"):
        _time_mask(only_first, tmin=11, tmax=12)
    with pytest.raises(ValueError, match="must be less than or equal to tmax"):
        _time_mask(only_first, tmin=10, sfreq=1)


def test_freq_mask():
    """Test _freq_mask with jittered, non-uniform and degenerate inputs."""
    N = 10
    freqs = np.arange(N).astype(float)
    assert _freq_mask(freqs, 1000.0, fmin=0, fmax=N - 1).sum() == N
    # a tiny negative offset must not drop any frequency
    jittered = freqs - 1e-10
    for fmin, fmax in (
        (0, N - 1),
        (None, N - 1),
        (None, None),
        (-np.inf, None),
        (None, np.inf),
    ):
        assert _freq_mask(jittered, 1000.0, fmin=fmin, fmax=fmax).sum() == N

    # non-uniformly spaced inputs
    x = np.array([4, 10])
    only_first = x[:1]
    assert _freq_mask(only_first, 1, fmin=10, raise_error=False).sum() == 0
    assert _freq_mask(only_first, 1, fmin=11, fmax=12, raise_error=False).sum() == 0
    for fmin, n_kept in (
        (10, 1),
        (6, 1),
        (5, 1),
        (4.5001, 1),
        (4.4999, 2),
        (4, 2),
    ):
        assert _freq_mask(x, sfreq=1, fmin=fmin).sum() == n_kept

    # degenerate cases
    with pytest.raises(ValueError, match="sfreq can not be None"):
        _freq_mask(only_first, sfreq=None, fmin=3, fmax=5)
    with pytest.raises(ValueError, match="No frequencies remain"):
        _freq_mask(only_first, sfreq=1, fmin=11, fmax=12)
    with pytest.raises(ValueError, match="must be less than or equal to fmax"):
        _freq_mask(only_first, sfreq=1, fmin=10)


def test_random_permutation():
    """Test random_permutation against a stored MATLAB randperm result."""
    # matlab output when we execute rng(42), randperm(10)
    matlab_randperm = np.array([7, 6, 5, 1, 4, 9, 10, 3, 8, 2])
    # MATLAB indices are 1-based, so shift them to compare with Python
    expected = matlab_randperm - 1

    n_samples, random_state = 10, 42
    python_randperm = random_permutation(n_samples, random_state)

    assert_array_equal(python_randperm, expected)


def test_cov_scaling():
    """Test that applying and undoing channel-type scalings is consistent.

    Two copies of the same covariance data are scaled and unscaled in place
    and must stay equal; an array round trip must give back the evoked data.
    """
    evoked = read_evokeds(ave_fname, condition=0, baseline=(None, 0), proj=True)
    # read the file twice to get two independent copies of the data
    cov, cov2 = (read_cov(cov_fname)["data"] for _ in range(2))
    assert_array_equal(cov, cov2)

    data_picks = pick_types(evoked.info, meg=True, eeg=True)
    evoked.pick([evoked.ch_names[idx] for idx in data_picks])
    picks_list = _picks_by_type(evoked.info)
    scalings = {"mag": 1e15, "grad": 1e13, "eeg": 1e6}

    # the scaling helpers modify their first argument in place
    for this_cov in (cov2, cov):
        _apply_scaling_cov(this_cov, picks_list, scalings=scalings)
    assert_array_equal(cov, cov2)
    assert cov.max() > 1

    for this_cov in (cov2, cov):
        _undo_scaling_cov(this_cov, picks_list, scalings=scalings)
    assert_array_equal(cov, cov2)
    assert cov.max() < 1

    # array version: apply followed by undo is a round trip
    data = evoked.data.copy()
    _apply_scaling_array(data, picks_list, scalings=scalings)
    _undo_scaling_array(data, picks_list, scalings=scalings)
    assert_allclose(data, evoked.data, atol=1e-20)


@pytest.mark.parametrize("ndim", (2, 3))
def test_reg_pinv(ndim):
    """Test _reg_pinv against np.linalg.pinv on a rank-deficient matrix.

    Covers the rank-deficiency warning, explicit and automatic rank,
    regularization, a custom ``rcond`` and an all-zero input.
    """
    # 3x3 matrix whose first and last rows are identical (rank 2)
    base = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
    # prepend singleton axes until the array has ``ndim`` dimensions
    a = base.reshape((1,) * (ndim - 2) + base.shape)

    # A rank-deficient matrix without regularization emits a specific warning
    with pytest.warns(RuntimeWarning, match="deficient"):
        _reg_pinv(a, reg=0.0)

    # Inversion with an explicit rank
    reference = np.linalg.pinv(a, hermitian=True)
    inverse, loading_factor, rank = _reg_pinv(a, rank=2)
    assert loading_factor == 0
    assert rank == 2
    assert_allclose(reference, inverse, atol=1e-14)

    # Inversion with automatic rank detection
    inverse, _, estimated_rank = _reg_pinv(a, rank=None)
    assert_allclose(reference, inverse, atol=1e-14)
    assert estimated_rank == 2

    # Adding regularization
    inverse, loading_factor, estimated_rank = _reg_pinv(a, reg=2)
    # The diagonal of the matrix is all ones, so the loading factor should
    # equal the regularization parameter
    assert loading_factor == 2
    # The estimated rank should be that of the non-regularized matrix
    assert estimated_rank == 2
    # Compare with NumPy applied to the explicitly loaded matrix
    loaded = a + loading_factor * np.eye(3)
    reference = np.linalg.pinv(loaded, hermitian=True)
    assert_allclose(reference, inverse, atol=1e-14)

    # Setting rcond
    reference = np.linalg.pinv(a, rcond=0.5)
    inverse, _, estimated_rank = _reg_pinv(a, rcond=0.5)
    assert_allclose(reference, inverse, atol=1e-14)
    assert estimated_rank == 1

    # Inverting an all-zero covariance
    zeros = np.zeros((3, 3))
    inverse, loading_factor, estimated_rank = _reg_pinv(zeros, reg=2)
    assert_array_equal(inverse, 0)
    assert loading_factor == 0
    assert estimated_rank == 0


def test_object_size():
    """Test object_size bounds for common objects and shared array memory."""
    size_f32 = object_size(np.ones(10, np.float32))
    size_f64 = object_size(np.ones(10, np.float64))
    assert size_f32 < size_f64

    # (exclusive lower bound, exclusive upper bound, object)
    bounded_objects = (
        (0, 60, ""),
        (0, 30, 1),
        (0, 30, 1.0),
        (0, 70, "foo"),
        (0, 150, np.ones(0)),
        (0, 150, np.int32(1)),
        (150, 500, np.ones(20)),
        (30, 400, dict()),
        (400, 1000, dict(a=np.ones(50))),
        (200, 900, _eye_array(20, format="csc")),
        (200, 900, _eye_array(20, format="csr")),
    )
    for lower, upper, obj in bounded_objects:
        size = object_size(obj)
        assert lower < size < upper, f"{lower} < {size} < {upper}:\n{obj}"

    # views work properly
    holder = dict(a=1)
    assert object_size(holder) < 1000
    big = np.ones(100000, float)
    holder["a"] = big
    nb = big.nbytes

    def _assert_array_counted_once():
        # total must stay within 1% above the size of a single array
        sz = object_size(holder)
        assert nb < sz < nb * 1.01

    _assert_array_counted_once()
    # the very same array stored under a second key
    holder["b"] = big
    _assert_array_counted_once()
    # a read-only view of the array stored under the second key
    readonly_view = big.view()
    readonly_view.flags.writeable = False
    holder["b"] = readonly_view
    assert big.flags.writeable
    _assert_array_counted_once()


def test_object_diff_with_nan():
    """Test that object_diff treats NaNs at matching positions as equal."""
    reference = np.array([1, np.nan, 0])
    same = np.array([1, np.nan, 0])
    shuffled = np.array([np.nan, 1, 0])

    # arrays: NaNs in the same place match, NaNs in different places do not
    assert object_diff(reference, same) == ""
    assert object_diff(reference, shuffled) != ""
    # scalars
    assert object_diff(np.nan, np.nan) == ""
    assert object_diff(np.nan, 3.5) == " value mismatch (nan, 3.5)\n"


def test_hash():
    """Test dictionary hashing and comparison functions.

    A nested dictionary mixing several value types is deep-copied and then
    modified in different ways; ``object_diff`` must report each difference
    (in both argument orders) and ``object_hash`` must change accordingly.
    Sparse arrays and NumPy scalars are checked at the end.
    """
    # does hashing all of these types work:
    # {dict, list, tuple, ndarray, str, float, int, None}
    d0 = dict(a=dict(a=0.1, b="fo", c=1), b=[1, "b"], c=(), d=np.ones(3), e=None)
    d0[1] = None
    d0[2.0] = b"123"

    # an unmodified deep copy compares and hashes as equal
    d1 = deepcopy(d0)
    assert len(object_diff(d0, d1)) == 0
    assert len(object_diff(d1, d0)) == 0
    assert object_hash(d0) == object_hash(d1)

    # change values slightly (d1["d"] is a copy, so d0 is left untouched)
    d1["data"] = np.ones(3, int)
    d1["d"][0] = 0
    assert object_hash(d0) != object_hash(d1)

    # changed nested value
    d1 = deepcopy(d0)
    assert object_hash(d0) == object_hash(d1)
    d1["a"]["a"] = 0.11
    assert len(object_diff(d0, d1)) > 0
    assert len(object_diff(d1, d0)) > 0
    assert object_hash(d0) != object_hash(d1)

    d1 = deepcopy(d0)
    assert object_hash(d0) == object_hash(d1)
    d1["a"]["d"] = 0  # non-existent key
    assert len(object_diff(d0, d1)) > 0
    assert len(object_diff(d1, d0)) > 0
    assert object_hash(d0) != object_hash(d1)

    d1 = deepcopy(d0)
    assert object_hash(d0) == object_hash(d1)
    d1["b"].append(0)  # different-length lists
    assert len(object_diff(d0, d1)) > 0
    assert len(object_diff(d1, d0)) > 0
    assert object_hash(d0) != object_hash(d1)

    d1 = deepcopy(d0)
    assert object_hash(d0) == object_hash(d1)
    d1["e"] = "foo"  # non-None
    assert len(object_diff(d0, d1)) > 0
    assert len(object_diff(d1, d0)) > 0
    assert object_hash(d0) != object_hash(d1)

    # StringIO values: d1 holds an empty buffer, d2 one containing "foo"
    d1 = deepcopy(d0)
    d2 = deepcopy(d0)
    d1["e"] = StringIO()
    d2["e"] = StringIO()
    d2["e"].write("foo")
    # None vs. StringIO
    assert len(object_diff(d0, d1)) > 0
    assert len(object_diff(d1, d0)) > 0
    # StringIO vs. StringIO with different contents; d2 was previously built
    # but never compared, so this case was not actually exercised
    assert len(object_diff(d1, d2)) > 0
    assert len(object_diff(d2, d1)) > 0

    # changed value under a non-string key
    d1 = deepcopy(d0)
    d1[1] = 2
    assert len(object_diff(d0, d1)) > 0
    assert len(object_diff(d1, d0)) > 0
    assert object_hash(d0) != object_hash(d1)

    # generators (and other types) not supported
    d1 = deepcopy(d0)
    d2 = deepcopy(d0)
    d1[1] = (x for x in d0)
    d2[1] = (x for x in d0)
    pytest.raises(RuntimeError, object_diff, d1, d2)
    pytest.raises(RuntimeError, object_hash, d1)

    # sparse arrays: format, element values, shape and type are all compared
    x = _eye_array(2, format="csc")
    y = _eye_array(2, format="csr")
    assert "type mismatch" in object_diff(x, y)
    y = _eye_array(2, format="csc")
    assert len(object_diff(x, y)) == 0
    y[1, 1] = 2
    assert "elements" in object_diff(x, y)
    y = _eye_array(3, format="csc")
    assert "shape" in object_diff(x, y)
    y = 0
    assert "type mismatch" in object_diff(x, y)

    # smoke test for gh-4796
    assert object_hash(np.int64(1)) != 0
    assert object_hash(np.bool_(True)) != 0


@pytest.mark.parametrize("n_components", (None, 0.9999, 8, "mle"))
@pytest.mark.parametrize("whiten", (True, False))
def test_pca(n_components, whiten):
    """Test that _PCA matches scikit-learn's PCA up to component sign."""
    pytest.importorskip("sklearn")
    from sklearn.decomposition import PCA

    n_samples, n_dim = 1000, 10
    rng = np.random.RandomState(0)
    X = rng.randn(n_samples, n_dim)
    X[:, -1] = np.mean(X[:, :-1], axis=-1)  # true X dim is ndim - 1
    X_orig = X.copy()

    pca_skl = PCA(n_components, whiten=whiten, svd_solver="full")
    pca_mne = _PCA(n_components, whiten=whiten)
    # neither implementation may modify its input
    X_skl = pca_skl.fit_transform(X)
    assert_array_equal(X, X_orig)
    X_mne = pca_mne.fit_transform(X)
    assert_array_equal(X, X_orig)

    # the sign of each component is arbitrary, so align signs before comparing
    column_signs = np.sign(np.sum(X_skl * X_mne, axis=0))
    assert_allclose(X_skl, X_mne * column_signs)
    assert pca_mne.n_components_ == pca_skl.n_components_

    fitted_attributes = (
        "mean_",
        "components_",
        "explained_variance_",
        "explained_variance_ratio_",
    )
    for key in fitted_attributes:
        val_skl = getattr(pca_skl, key)
        val_mne = getattr(pca_mne, key)
        if key == "components_":
            row_signs = np.sign(np.sum(val_skl * val_mne, axis=1, keepdims=True))
            val_mne = val_mne * row_signs
        assert_allclose(val_skl, val_mne)

    # expected number of retained components for each parametrization
    if isinstance(n_components, float) or n_components == "mle":
        expected = n_dim - 1
    elif isinstance(n_components, int):
        expected = n_components
    else:
        assert n_components is None
        expected = n_dim
    assert pca_mne.n_components_ == expected


def test_array_equal_nan():
    """Test that _array_equal_nan treats NaNs at the same position as equal."""
    a = [1, np.nan, 0]
    b = a
    # this is the annoying behavior we avoid
    assert not np.array_equal(a, b)
    assert _array_equal_nan(a, b)

    # NaN in a different position is a real mismatch
    b = [np.nan, 1, 0]
    assert not _array_equal_nan(a, b)

    # all-NaN inputs
    a = [np.nan] * 2
    b = a
    assert _array_equal_nan(a, b)


def test_julian_conversions():
    """Test round-trip conversion between Julian day numbers and dates."""
    # https://aa.usno.navy.mil/data/docs/JulianDate.php
    # A.D. 1922 Jun 13  12:00:00.0  2423219.000000
    # A.D. 2018 Oct 3   12:00:00.0  2458395.000000

    jds = [2423219, 2458395, 2445701]
    dds = [date(1922, 6, 13), date(2018, 10, 3), date(1984, 1, 1)]

    # each pair must convert in both directions
    for dd, jd in zip(dds, jds):
        assert dd == _julian_to_date(jd)
        assert jd == _date_to_julian(dd)


def test_grand_average_empty_sequence():
    """Test that grand_average rejects an empty list with a ValueError."""
    nothing = []
    with pytest.raises(ValueError, match="Please pass a list of Evoked"):
        grand_average(nothing)


def test_grand_average_len_1():
    """Test that grand_average of one dataset warns and returns its data."""
    # condition=[0] makes read_evokeds return a list of length 1
    single = read_evokeds(ave_fname, condition=[0], proj=True)

    with pytest.warns(RuntimeWarning, match="Only a single dataset"):
        gave = grand_average(single)

    assert_allclose(gave.data, single[0].data)


def test_reuse_cycle():
    """Test _ReuseCycle iteration order with and without restored values."""
    vals = "abcde"
    iterable = _ReuseCycle(vals)

    def _take(count):
        # pull ``count`` items from the shared iterator and join them
        return "".join(next(iterable) for _ in range(count))

    # without restore, the values simply repeat in order
    assert _take(2 * len(vals)) == vals + vals
    # we're back to initial
    assert _take(2) == "ab"
    iterable.restore("a")
    assert _take(10) == "acdeabcdea"
    assert _take(4) == "bcde"
    # we're back to initial
    assert _take(3) == "abc"
    for val in "abc":
        iterable.restore(val)
    assert _take(5) == "abcde"
    # we're back to initial
    assert _take(3) == "abc"
    iterable.restore("a")
    iterable.restore("c")
    assert _take(4) == "acde"
    assert _take(5) == "abcde"
    # we're back to initial
    assert _take(3) == "abc"
    iterable.restore("c")
    iterable.restore("a")
    # restoring the same value a second time only warns
    with pytest.warns(RuntimeWarning, match="Could not find"):
        iterable.restore("a")
    assert _take(4) == "acde"
    assert _take(5) == "abcde"


@pytest.mark.parametrize("n", (0, 1, 10, 1000))
@pytest.mark.parametrize("d", (0.0001, 1, 2.5, 1000))
def test_arange_div(numba_conditional, n, d):
    """Test that _arange_div(n, d) matches np.arange(n) / d."""
    actual = numerics._arange_div(n, d)
    expected = np.arange(n) / d
    assert_allclose(actual, expected)


def test_custom_lru_cache():
    """Test our _custom_lru_cache implementation.

    Two functions are decorated with different maximum sizes. The test checks
    that each decoration adds an entry to the module-level cache registries,
    that repeated calls with the same arguments do not re-run the function,
    that the per-function cache size is bounded, and that an unsupported
    sparse format raises before the function is called.
    """
    # per-function call counters: index 0 for my_fun, index 1 for my_fun_2
    n_calls = [0, 0]
    # other caches may already be registered, so only check relative growth
    start_size = len(_LRU_CACHES)

    @_custom_lru_cache(2)
    def my_fun(*args):
        n_calls[0] += 1
        return ", ".join(arg.__class__.__name__ for arg in args)

    # decorating registered exactly one new cache; its key is the last one
    assert len(_LRU_CACHES) == start_size + 1
    fun_hash = list(_LRU_CACHES)[-1]
    assert _LRU_CACHE_MAXSIZES[fun_hash] == 2

    @_custom_lru_cache(1)
    def my_fun_2(*args):
        n_calls[1] += 1
        return ", ".join(arg.__class__.__name__ for arg in args)

    assert len(_LRU_CACHES) == start_size + 2
    fun_2_hash = list(_LRU_CACHES)[-1]
    assert _LRU_CACHE_MAXSIZES[fun_2_hash] == 1

    # decoration alone does not call the functions
    assert n_calls == [0, 0]
    assert my_fun(1, 2, 3) == "int, int, int"
    assert n_calls == [1, 0]
    assert my_fun_2(1, 2, 3.0) == "int, int, float"
    assert n_calls == [1, 1]
    # repeated calls use cached version
    assert my_fun(1, 2, 3) == "int, int, int"
    assert n_calls == [1, 1]
    assert my_fun_2(1, 2, 3.0) == "int, int, float"
    assert n_calls == [1, 1]
    assert len(_LRU_CACHES[fun_hash]) == 1
    assert len(_LRU_CACHES[fun_2_hash]) == 1
    # new arguments (including an ndarray) trigger a call and a second entry
    assert my_fun(1, np.array([2]), 3) == "int, ndarray, int"
    assert n_calls == [2, 1]
    assert len(_LRU_CACHES[fun_hash]) == 2
    # my_fun_2 has maxsize 1, so a new entry replaces the previous one
    assert my_fun_2(1, _eye_array(1, format="csc")) == "int, csc_array"
    assert n_calls == [2, 2]
    assert len(_LRU_CACHES[fun_2_hash]) == 1  # other got popped
    # we could add support for this eventually, but don't bother for now
    with pytest.raises(RuntimeError, match="Unsupported sparse type"):
        my_fun_2(1, _eye_array(1, format="coo"))
    assert n_calls == [2, 2]  # never did any computation


def test_replace_md5(tmp_path):
    """Test _replace_md5 with differing and with identical file contents.

    In both cases the ``.new`` file must be gone afterwards and the original
    file must hold the contents the ``.new`` file had.
    """
    target = tmp_path / "test"
    candidate = target.with_suffix(".new")
    target.write_text("abcd")
    candidate.write_text("abcde")
    assert target.is_file()
    assert candidate.is_file()

    # differing contents: the original ends up with the new contents
    _replace_md5(str(candidate))
    assert not candidate.is_file()
    assert target.read_text() == "abcde"

    # identical contents: the original is unchanged, the .new file is removed
    candidate.write_text(target.read_text())
    _replace_md5(str(candidate))
    assert target.read_text() == "abcde"
    assert not candidate.is_file()
